from sklearn.datasets import fetch_openml
from sklearn.utils import shuffle
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import normalize
import numpy as np
import pandas as pd 

import torch
import torchvision
from torchvision import datasets, transforms
from sklearn.cluster import KMeans
from collections import defaultdict

        

class load_notmnist_mnist_2:
    """MNIST contextual-bandit environment with one arm per digit class.

    Each ``step`` draws one training image and builds ``n_arm`` context rows.
    Row ``a`` holds the flattened image starting at column ``a``, with zeros
    elsewhere. The arm equal to the image's label receives reward 1.
    """

    def __init__(self):
        batch_size = 1
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        dataset1 = datasets.MNIST('./data', train=True, download=True,
                                  transform=transform)
        # Kept on self so ``step`` can start a new pass once this one ends.
        self.train_loader = torch.utils.data.DataLoader(
            dataset1, batch_size=batch_size, shuffle=True, num_workers=2)
        self.dataiter = iter(self.train_loader)
        self.n_arm = 10            # digit classes 0-9
        self.num_user = self.n_arm  # one user per class, like the sibling envs
        self.act_dim = 28 * 28     # length of a flattened image
        # Arm ``a`` writes the image at column offset ``a``, so the last arm
        # needs ``n_arm - 1`` columns beyond the image itself.
        self.dim = self.act_dim + self.n_arm - 1

    def step(self):
        """Draw the next image and return ``(X, rwd)``.

        Returns:
            X: array of shape ``(n_arm, dim)``; row ``a`` contains the
               flattened image at columns ``a:a + act_dim``.
            rwd: array of length ``n_arm`` with 1 at the image's label.
        """
        try:
            x, y = next(self.dataiter)
        except StopIteration:
            # Epoch exhausted: begin a new pass over the loader.
            self.dataiter = iter(self.train_loader)
            x, y = next(self.dataiter)
        d = x.numpy()[0].reshape(self.act_dim)
        target = y.item()
        X = np.zeros((self.n_arm, self.dim))
        for a in range(self.n_arm):
            X[a, a:a + self.act_dim] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X, rwd

    

    



    

    

class Bandit_multi:
    """Contextual-bandit environment built from an OpenML classification dataset.

    Two consecutive integer labels share one arm (``arm = label // 2``). Each
    ``step`` presents one sample. The context for arm ``i`` is the feature
    vector preceded by ``i`` zeros and followed by ``n_arm - i`` zeros, giving
    length ``dim``. The arm that contains the sample's label gets reward 1.

    Args:
        name: one of the keys of ``_DATASETS``.

    Raises:
        RuntimeError: if ``name`` is not a supported dataset.
    """

    # name -> (OpenML name, version, one-hot encode X, remap labels to 1..K)
    _DATASETS = {
        'covertype': ('covertype', 3, True, False),
        'MagicTelescope': ('MagicTelescope', 1, False, True),
        'shuttle': ('shuttle', 1, False, False),
        'adult': ('adult', 2, True, True),
        'mushroom': ('mushroom', 1, True, True),
        'fashion': ('Fashion-MNIST', 1, False, False),
    }

    def __init__(self, name):
        X, y = self._load(name)
        # Shuffle data
        self.X, self.y = shuffle(X, y)
        # Integer class labels (builtin ``int``; ``np.int`` no longer exists).
        self.y_arm = np.asarray(self.y.values).astype(int)
        # cursor and other variables
        self.cursor = 0
        self.size = self.y.shape[0]
        self.n_arm = int(np.max(self.y_arm) / 2 + 1)
        self.act_dim = self.X.shape[1]
        self.dim = self.act_dim + self.n_arm
        self.num_user = np.max(self.y_arm) + 1
        print(self.dim)
        print(self.n_arm)

    @classmethod
    def _load(cls, name):
        """Fetch dataset ``name`` from OpenML and return preprocessed ``(X, y)``."""
        try:
            openml_name, version, one_hot, remap = cls._DATASETS[name]
        except KeyError:
            raise RuntimeError('Dataset does not exist') from None
        X, y = fetch_openml(openml_name, version=version, return_X_y=True)
        if one_hot:
            X = pd.get_dummies(X)
        if remap:
            y = y.map(cls._label_map(y))
        # avoid nan, set nan as -1
        X[np.isnan(X)] = -1
        return normalize(X), y

    @staticmethod
    def _label_map(y):
        """Map each distinct label of ``y`` to 1..K in sorted order."""
        # Sorting makes the mapping reproducible: iteration order of a set of
        # strings changes between interpreter runs (hash randomization).
        return {value: i + 1 for i, value in enumerate(sorted(set(y.values)))}

    def step(self):
        """Return ``(X_n, rwd)`` for the next sample, wrapping at the end.

        Returns:
            X_n: array of shape ``(n_arm, dim)``; row ``i`` is the sample's
                 features shifted right by ``i`` columns.
            rwd: array of length ``n_arm`` with 1 at ``label // 2``.
        """
        if self.cursor >= len(self.X):
            self.cursor = 0
        x = self.X[self.cursor]
        target = int(self.y_arm[self.cursor]) // 2
        X_n = np.array([
            np.concatenate((np.zeros(i), x, np.zeros(self.n_arm - i)), axis=0)
            for i in range(self.n_arm)
        ])
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        self.cursor += 1
        return X_n, rwd


class load_emnist_letter_1d:
    """EMNIST-letters contextual bandit with one arm per letter.

    Arm ``i``'s context is the flattened 28x28 image preceded by
    ``num_zeros * i`` zeros and followed by ``num_zeros * (num_class - i - 1)``
    zeros, so every row has length ``dim``. The arm matching the letter
    receives reward 1.

    Args:
        is_shuffle: whether the data loader shuffles samples on each pass.
    """

    def __init__(self, is_shuffle=True):
        # Fetch data
        batch_size = 1
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        trainset = torchvision.datasets.EMNIST(
            root='./data', split="letters", train=True, download=True,
            transform=transform)
        # Kept on self so ``step`` can restart iteration after a full pass.
        self.trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size, shuffle=is_shuffle, num_workers=2)
        self.dataiter = iter(self.trainloader)

        self.n_arm = 26
        self.num_zeros = 10
        self.num_class = 26
        self.num_user = 26
        self.dim = 28 * 28 + self.num_zeros * (self.num_class - 1)

    def step(self):
        """Draw the next image and return ``(X_n, rwd)``.

        Returns:
            X_n: array of shape ``(n_arm, dim)``; row ``i`` holds the image at
                 columns ``num_zeros * i : num_zeros * i + 784``.
            rwd: array of length ``n_arm`` with 1 at the letter's arm.
        """
        try:
            x, y = next(self.dataiter)
        except StopIteration:
            # Epoch exhausted: begin a new pass over the loader.
            self.dataiter = iter(self.trainloader)
            x, y = next(self.dataiter)
        d = x.numpy()[0][0].reshape(28 * 28)
        # EMNIST "letters" labels are 1-based; shift to arm index 0..25.
        target = y.item() - 1
        X_n = np.array([
            np.concatenate((np.zeros(self.num_zeros * i), d,
                            np.zeros(self.num_zeros * (self.num_class - i - 1))),
                           axis=0)
            for i in range(self.n_arm)
        ])
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X_n, rwd

        
